{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '../../../Utilities/')\n",
    "import argparse\n",
    "import os\n",
    "import torch\n",
    "from collections import OrderedDict\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.io\n",
    "from scipy.interpolate import griddata\n",
    "from plotting import newfig, savefig\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import matplotlib.gridspec as gridspec\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.utils import save_image\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from torchvision import datasets\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import seaborn as sns\n",
    "import pylab as py\n",
    "import time\n",
    "from pyDOE import lhs\n",
    "import warnings\n",
    "sys.path.insert(0, '../../../Scripts/')\n",
    "from models_pde import Generator, Discriminator, Q_Net\n",
    "# Import only what this notebook uses from pig: a star import hides where\n",
    "# names come from and can silently shadow the imports above.\n",
    "from pig import Schrodinger_PIG\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Seed every global RNG used below: numpy drives the training-point sampling\n",
    "# (np.random.choice and pyDOE.lhs in the Load Data cell); torch presumably\n",
    "# drives network weight initialisation and any noise drawn inside PIG.\n",
    "SEED = 1234\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CUDA support: use GPU 1 when the machine has more than one GPU, GPU 0 on a\n",
    "# single-GPU machine (where 'cuda:1' is an invalid device ordinal and the first\n",
    "# .to(device) call would raise), and the CPU when CUDA is unavailable.\n",
    "if torch.cuda.is_available():\n",
    "    gpu_index = 1 if torch.cuda.device_count() > 1 else 0\n",
    "    device = torch.device('cuda', gpu_index)\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Training (passed positionally to Schrodinger_PIG below, except lambda_q) ---\n",
    "num_epochs = 50000\n",
    "lambda_val = 2\n",
    "noise = 0.1\n",
    "# NOTE(review): lambda_q is defined but never passed to Schrodinger_PIG in this\n",
    "# notebook; confirm whether pig expects it.\n",
    "lambda_q = 0.5\n",
    "\n",
    "# --- Network architecture: (hidden width, number of layers) per model ---\n",
    "d_hid_dim, d_num_layer = 100, 3   # discriminator D\n",
    "g_hid_dim, g_num_layer = 100, 4   # generator G\n",
    "q_hid_dim, q_num_layer = 50, 4    # Q network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Domain bounds: column 0 is x, column 1 is t\n",
    "lb = np.array([-5.0, 0.0])\n",
    "ub = np.array([5.0, np.pi/2])\n",
    "\n",
    "# Sizes of the initial-condition, boundary and collocation training sets\n",
    "N0, N_b, N_f = 50, 50, 20000\n",
    "\n",
    "data = scipy.io.loadmat('../../../datasets/NLS.mat')\n",
    "\n",
    "def _as_column(arr):\n",
    "    \"\"\"Return a flattened copy of `arr` as an (n, 1) column.\"\"\"\n",
    "    return arr.flatten()[:, None]\n",
    "\n",
    "t = _as_column(data['tt'])\n",
    "x = _as_column(data['x'])\n",
    "\n",
    "# Exact complex solution: rows index x, columns index t\n",
    "Exact = data['uu']\n",
    "Exact_u, Exact_v = np.real(Exact), np.imag(Exact)\n",
    "Exact_h = np.sqrt(Exact_u**2 + Exact_v**2)\n",
    "\n",
    "# Evaluation grid; X and T have shape (len(t), len(x)) and X_star lists its\n",
    "# points row by row, matching the flattened transposed exact fields.\n",
    "X, T = np.meshgrid(x, t)\n",
    "X_star = np.hstack((_as_column(X), _as_column(T)))\n",
    "u_star, v_star, h_star = (_as_column(field.T) for field in (Exact_u, Exact_v, Exact_h))\n",
    "\n",
    "# Random training subsets. The draw order (x indices, t indices, then LHS) fixes\n",
    "# which points the seeded numpy stream yields, so keep it unchanged.\n",
    "idx_x = np.random.choice(x.shape[0], N0, replace=False)\n",
    "x0 = x[idx_x, :]\n",
    "u0 = Exact_u[idx_x, 0:1]  # first time column\n",
    "v0 = Exact_v[idx_x, 0:1]\n",
    "\n",
    "idx_t = np.random.choice(t.shape[0], N_b, replace=False)\n",
    "tb = t[idx_t, :]\n",
    "\n",
    "# Latin-hypercube collocation points mapped from the unit square onto [lb, ub]\n",
    "X_f = lb + (ub - lb) * lhs(2, N_f)\n",
    "\n",
    "X0 = np.concatenate((x0, 0*x0), 1)             # (x0, 0)\n",
    "Y0 = np.concatenate((u0, v0), 1)               # (u0, v0)\n",
    "X_lb = np.concatenate((0*tb + lb[0], tb), 1)   # (lb[0], tb)\n",
    "X_ub = np.concatenate((0*tb + ub[0], tb), 1)   # (ub[0], tb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the three networks on `device`. The in_dim/out_dim values presumably\n",
    "# match the tensors pig feeds them - TODO confirm against pig.py.\n",
    "D = Discriminator(in_dim=4, out_dim=1,\n",
    "                  hid_dim=d_hid_dim, num_layers=d_num_layer).to(device)\n",
    "G = Generator(in_dim=3, out_dim=2,\n",
    "              hid_dim=g_hid_dim, num_layers=g_num_layer).to(device)\n",
    "Q = Q_Net(in_dim=4, out_dim=1,\n",
    "          hid_dim=q_hid_dim, num_layers=q_num_layer).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional arguments: their order must match Schrodinger_PIG's signature in pig.\n",
    "PIG = Schrodinger_PIG(\n",
    "    X0, Y0, X_f, X_lb, X_ub,       # training data built in the Load Data cell\n",
    "    X_star, h_star,                # evaluation grid and exact |h|\n",
    "    G, D, Q,                       # networks\n",
    "    device, num_epochs, lambda_val, noise,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the networks; the schedule comes from the arguments given to Schrodinger_PIG above.\n",
    "PIG.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte-Carlo estimate of the predictive distribution: PIG.predict is called\n",
    "# nsamples times and the draws are averaged (presumably each call samples fresh\n",
    "# generator noise; otherwise h_pred_var below would be zero).\n",
    "nsamples = 500\n",
    "u_pred_list, v_pred_list, h_pred_list = [], [], []\n",
    "f_u_pred_list, f_v_pred_list = [], []\n",
    "for _ in range(nsamples):\n",
    "    u_pred, v_pred, f_u_pred, f_v_pred = PIG.predict(X_star)\n",
    "    u_pred_list.append(u_pred)\n",
    "    v_pred_list.append(v_pred)\n",
    "    f_u_pred_list.append(f_u_pred)\n",
    "    f_v_pred_list.append(f_v_pred)\n",
    "    h_pred_list.append(np.sqrt(u_pred**2 + v_pred**2))\n",
    "\n",
    "u_pred_arr = np.array(u_pred_list)\n",
    "v_pred_arr = np.array(v_pred_list)\n",
    "f_u_pred_arr = np.array(f_u_pred_list)\n",
    "f_v_pred_arr = np.array(f_v_pred_list)\n",
    "h_pred_arr = np.array(h_pred_list)\n",
    "\n",
    "# Sample means, plus the sample variance of |h| for the uncertainty plots\n",
    "u_pred = u_pred_arr.mean(axis=0)\n",
    "v_pred = v_pred_arr.mean(axis=0)\n",
    "f_u_pred = f_u_pred_arr.mean(axis=0)\n",
    "f_v_pred = f_v_pred_arr.mean(axis=0)\n",
    "h_pred = h_pred_arr.mean(axis=0)\n",
    "h_pred_var = h_pred_arr.var(axis=0)\n",
    "\n",
    "# Mean-squared PDE residual of the averaged residual fields\n",
    "residual = (f_u_pred**2).mean() + (f_v_pred**2).mean()\n",
    "\n",
    "# Relative L2 errors against the exact solution\n",
    "error_u = np.linalg.norm(u_star-u_pred,2)/np.linalg.norm(u_star,2)\n",
    "error_v = np.linalg.norm(v_star-v_pred,2)/np.linalg.norm(v_star,2)\n",
    "error_h = np.linalg.norm(h_star-h_pred,2)/np.linalg.norm(h_star,2)\n",
    "\n",
    "# X_star was built as hstack(X.flatten(), T.flatten()), so each prediction is\n",
    "# already ordered like the (X, T) grid and a reshape recovers the 2-D field\n",
    "# exactly. Interpolating with griddata onto those same points would only\n",
    "# reproduce the values (up to round-off) at the cost of a Delaunay\n",
    "# triangulation per field.\n",
    "grid_shape = X.shape\n",
    "U_pred = u_pred.reshape(grid_shape)\n",
    "V_pred = v_pred.reshape(grid_shape)\n",
    "H_pred = h_pred.reshape(grid_shape)\n",
    "H_pred_var = h_pred_var.reshape(grid_shape)\n",
    "\n",
    "FU_pred = f_u_pred.reshape(grid_shape)\n",
    "FV_pred = f_v_pred.reshape(grid_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative L2 errors of each field, then the mean-squared PDE residual\n",
    "for name, err in (('u', error_u), ('v', error_v), ('h', error_h)):\n",
    "    print('Error ' + name + ':', err)\n",
    "print('Residual: %e' % (residual))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "######################################################################\n",
    "############################# Plotting ###############################\n",
    "######################################################################\n",
    "t = data['tt'].flatten()[:,None]\n",
    "x = data['x'].flatten()[:,None]\n",
    "\n",
    "# All points that carry data (initial condition + both boundaries), overlaid as crosses\n",
    "X0 = np.concatenate((x0, 0*x0), 1) # (x0, 0)\n",
    "X_lb = np.concatenate((0*tb + lb[0], tb), 1) # (lb[0], tb)\n",
    "X_ub = np.concatenate((0*tb + ub[0], tb), 1) # (ub[0], tb)\n",
    "X_u_train = np.vstack([X0, X_lb, X_ub])\n",
    "\n",
    "\n",
    "def plot_h_heatmap(field, title, slice_indices=(75, 100, 125)):\n",
    "    \"\"\"Draw `field` as a heat map over the (t, x) domain and return (fig, ax).\n",
    "\n",
    "    field: values on the (X, T) grid, shape (len(t), len(x)); it is transposed\n",
    "        so t runs horizontally and x vertically.\n",
    "    title: axes title (matplotlib mathtext allowed inside $...$).\n",
    "    slice_indices: indices into t marked with dashed vertical lines; they match\n",
    "        the time slices plotted in the next cell.\n",
    "    \"\"\"\n",
    "    fig = plt.figure(figsize=(20, 10))\n",
    "    # Full-figure axes kept only as an invisible canvas behind the gridspec panel\n",
    "    fig.add_subplot(111).axis('off')\n",
    "\n",
    "    gs0 = gridspec.GridSpec(1, 2)\n",
    "    gs0.update(top=1-0.06, bottom=1-1/3, left=0.15, right=0.85, wspace=0)\n",
    "    ax = plt.subplot(gs0[:, :])\n",
    "\n",
    "    image = ax.imshow(field.T, interpolation='nearest', cmap='YlGnBu',\n",
    "                      extent=[lb[1], ub[1], lb[0], ub[0]],\n",
    "                      origin='lower', aspect='auto')\n",
    "    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(image, cax=cax)\n",
    "\n",
    "    ax.plot(X_u_train[:,1], X_u_train[:,0], 'kx',\n",
    "            label='Data (%d points)' % (X_u_train.shape[0]),\n",
    "            markersize=4, clip_on=False)\n",
    "\n",
    "    line = np.linspace(x.min(), x.max(), 2)[:,None]\n",
    "    for idx in slice_indices:\n",
    "        ax.plot(t[idx]*np.ones((2,1)), line, 'k--', linewidth=1)\n",
    "\n",
    "    ax.set_xlabel('$t$')\n",
    "    ax.set_ylabel('$x$')\n",
    "    ax.legend(frameon=False, loc='best')\n",
    "    ax.set_title(title, fontsize=10)\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "fig, ax = plot_h_heatmap(H_pred, '$|h(t,x)|$')\n",
    "# Keep the word outside $...$: inside mathtext it is set in italics and\n",
    "# spaces are dropped.\n",
    "fig, ax = plot_h_heatmap(H_pred_var, 'Variance of $|h(t,x)|$')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####### Row 1: |h(t,x)| at fixed times with a two-std band ##################\n",
    "slice_indices = (75, 100, 125)\n",
    "\n",
    "fig = plt.figure(figsize=(40, 20))\n",
    "ax = fig.add_subplot(111)\n",
    "# Hide the full-figure background axes; without this its frame and 0-1 ticks\n",
    "# show behind the slice panels on matplotlib versions that keep overlapping axes.\n",
    "ax.axis('off')\n",
    "\n",
    "gs1 = gridspec.GridSpec(1, 4)\n",
    "gs1.update(top=1-1/3, bottom=0, left=0.1, right=0.9, wspace=0.5)\n",
    "\n",
    "for col, idx in enumerate(slice_indices):\n",
    "    ax = plt.subplot(gs1[0, col])\n",
    "    # Exact_h columns index t, while the gridded predictions have t on their rows.\n",
    "    ax.plot(x, Exact_h[:, idx], 'b-', linewidth=2, label='Exact')\n",
    "    ax.plot(x, H_pred[idx, :], 'r--', linewidth=2, label='Prediction')\n",
    "    # Clip round-off negatives so the square root never leaves NaN gaps in the band.\n",
    "    std = np.sqrt(np.clip(H_pred_var[idx, :], 0.0, None))\n",
    "    lower = H_pred[idx, :] - 2.0*std\n",
    "    upper = H_pred[idx, :] + 2.0*std\n",
    "    ax.fill_between(x.ravel(), lower.ravel(), upper.ravel(),\n",
    "                    facecolor='orange', alpha=0.5, label='Two std band')\n",
    "    ax.set_xlabel('$x$')\n",
    "    ax.set_ylabel('$|h(t,x)|$')\n",
    "    # t has shape (len(t), 1): format the scalar t[idx, 0], not a 1-element array.\n",
    "    ax.set_title('$t = %.2f$' % t[idx, 0], fontsize=10)\n",
    "    ax.axis('square')\n",
    "    ax.set_xlim([-5.1, 5.1])\n",
    "    ax.set_ylim([-0.1, 5.1])\n",
    "    if col == 1:\n",
    "        # Single shared legend, anchored below the middle panel\n",
    "        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.8), ncol=5, frameon=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
